import json
import openai
from tqdm import tqdm
import pandas as pd
import argparse
import os
import sys
from typing import Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re

def parse_args() -> argparse.Namespace:
    """Parse ``sys.argv`` for one slice-worker of the classification job.

    Each worker handles its own ``[--start, --end]`` row range, so several
    workers can run in parallel over one CSV. ``--csv_path`` is mandatory;
    argparse exits with a usage message when it is missing.
    """
    cli = argparse.ArgumentParser(
        description="Run zero-shot persuasion-strategy classification over a slice of the dataset.",
    )
    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--start", dict(type=int, default=0, help="Start index (inclusive) of the slice.")),
        ("--end", dict(type=int, default=None, help="End index (inclusive) of the slice.")),
        ("--output_dir", dict(type=str, default="persuasion_results", help="Directory to write JSON results.")),
        ("--csv_path", dict(type=str, required=True, help="Path to the input CSV with columns video_id,story")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    namespace = cli.parse_args()
    return namespace

# ---------------------------------------------------------------------------
# Zero-shot classification system prompt
# ---------------------------------------------------------------------------
# System-turn instruction sent with every story: pick exactly one vocabulary
# key and reply with that key only.
# NOTE(review): this text calls the vocabulary "persuasion_vocab", but the
# vocabulary string itself is labelled "Persuasion Strategies Vocabulary:".
# Consider aligning the names. Any wording change alters model inputs, so do
# it only between full experiment runs.
SYSTEM_PROMPT = " ".join(
    [
        "You will be given a dictionary called persuasion_vocab that lists "
        "persuasion strategies with their definitions.",
        "You will also be given the STORY text of a video advertisement.",
        "Your task is to choose the SINGLE most relevant persuasion strategy key "
        "from persuasion_vocab that is most central to how the advertisement "
        "seeks to persuade viewers.",
        "Output ONLY the strategy key, nothing else.",
    ]
)
# Persuasion vocabulary: strategy key -> definition, in the order shown to the model.
PERSUASION_VOCAB = {
    "Authority": "Authority indicated through expertise, source of power, third-party approval, credentials, and awards",
    "Social Identity": "Normative influence, which involves conformity with the positive expectations of another (person or group), using the idea of everyone else is doing it to influence behavior.",
    "Social Proof": "Use of testimonials, reviews, or other social validation to demonstrate popularity or trustworthiness, increasing confidence and encouraging purchase.",
    "Reciprocity": "Creates a future obligation by giving something first, prompting repayment, often unequal.",
    "Foot in the door": "Starts with small requests followed by larger ones to facilitate compliance while maintaining cognitive coherence.",
    "Overcoming Reactance": "Reduces resistance by postponing consequences, focusing on realistic concerns, forewarning, acknowledging resistance, raising self-esteem or efficacy.",
    "Concreteness": "Uses specific, tangible details or examples to make an abstract concept concrete and relatable, creating vivid impressions.",
    "Anchoring and Comparison": "Uses a reference point (anchor) or side-by-side comparisons to influence perceptions of value or superiority.",
    "Social Impact": "Highlights a positive effect on society or community, promoting social causes or positive behaviors.",
    "Scarcity": "Increases perceived value when availability is limited due to psychological reactance or rarity heuristic.",
    "Unclear": "Strategy is unclear, not in English, or no persuasive strategy is central to the advertisement",
}

# Text block embedded in the user message. Entries are rendered as
# 'Key':'Definition' and comma-joined without spaces, inside "{ " ... "}".
topics = (
    "Persuasion Strategies Vocabulary: { "
    + ",".join(f"'{key}':'{definition}'" for key, definition in PERSUASION_VOCAB.items())
    + "}"
)

def _cell_to_text(value) -> str:
    """Return a CSV cell as a string, mapping missing cells (None / NaN) to ``""``.

    ``pd.read_csv`` stores empty cells as float NaN. Without this check,
    ``str(nan)`` would put the literal text ``"nan"`` into prompts and video IDs.
    """
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return ""
    return str(value)


def _write_results(path: str, results: list) -> None:
    """Dump ``results`` as indented JSON to ``path`` via a temp file and ``os.replace``.

    ``os.replace`` swaps the file in atomically, so an interrupted run leaves
    the previous complete snapshot rather than a truncated JSON file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, path)


def _load_model(model_name: str):
    """Load ``model_name``'s tokenizer and a 4-bit quantized causal LM.

    Returns:
        ``(tokenizer, model)``. The model is placed with ``device_map="auto"``.
    """
    # Passing ``load_in_4bit=True`` straight to ``from_pretrained`` is
    # deprecated. An explicit BitsAndBytesConfig is the supported equivalent.
    from transformers import BitsAndBytesConfig

    tok = AutoTokenizer.from_pretrained(model_name)
    lm = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
    return tok, lm


def _classify(messages, tok, lm) -> str:
    """Generate a reply for one chat and normalize it into a lowercase strategy key.

    Sampling is enabled (temperature 0.85, min_p 0.1), so repeated runs can
    return different keys for the same story.
    """
    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,  # Qwen3 chat-template flag: no thinking section
    ).to(lm.device)

    with torch.no_grad():
        outputs = lm.generate(
            input_ids=input_ids,
            # Single unpadded sequence, so every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=300,
            temperature=0.85,
            do_sample=True,
            min_p=0.1,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    raw = tok.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
    # Drop surrounding quotes, periods, commas and spaces around the key.
    return raw.lower().strip("'\". ,")


def main():
    """Classify the central persuasion strategy of every story in one CSV slice.

    Rows ``--start``..``--end`` of ``--csv_path`` are each sent to Qwen3-32B
    together with the persuasion vocabulary. Both bounds are inclusive, and
    ``--end`` defaults to the last row. After every row, the accumulated
    results are rewritten to
    ``<output_dir>/persuasion_results_<start>_<end>.json``.

    A row whose generation fails is recorded as ``"error_inference"``. Any
    other per-row failure is printed and the row is skipped. Exits with
    status 1 if the CSV cannot be read or ``--start`` is negative.
    """
    # Model and tokenizer stay module-level globals, as in the original script.
    global model, tokenizer

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Read and validate the input before the slow model load, so bad arguments fail fast.
    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    if args.start < 0:
        # A negative start would silently slice from the end of the records.
        print(f"Invalid --start {args.start}: must be >= 0")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (both bounds inclusive).
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"persuasion_results_{start_idx}_{end_idx}.json")

    if not slice_records:
        # Nothing to classify: still write the (empty) result file, skip the model load.
        _write_results(output_path, results)
        print(f"Finished processing. Results saved to {output_path}")
        return

    tokenizer, model = _load_model("Qwen/Qwen3-32B")

    for rec in tqdm(slice_records, desc=f"Persuasion Eval {start_idx}-{end_idx}"):
        # Bound before the try so the outer handler never sees an unbound or stale id.
        video_id = ""
        try:
            video_id = _cell_to_text(rec.get('video_id')).strip()
            # str.split() with no argument splits on every whitespace character
            # (including newlines and form feeds), so this collapses all of them.
            cleaned_text = " ".join(_cell_to_text(rec.get('story')).split())

            # Build zero-shot prompt
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": f"{topics}\n\nStory: {cleaned_text}"},
            ]

            try:
                pred_topic = _classify(messages, tokenizer, model)
            except Exception as e:
                print(f"Error during Qwen inference for key {video_id}: {e}")
                pred_topic = "error_inference"

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'predicted_topic': pred_topic,
            })

            # Incremental save so partial progress survives an interrupted run.
            _write_results(output_path, results)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")

    print(f"Finished processing. Results saved to {output_path}")

if __name__ == "__main__":
    # Script entry point: importing this module does not start a run.
    main()




